# Setup ----
# Run this script in a fresh R session with the working directory set to the
# project root (the folder containing STAN/), so the relative model path used
# below resolves. Do not clear the workspace or call setwd() from the script.
library(cmdstanr)
library(bayesplot)
library(posterior)
library(mcmcplots)
library(magrittr)
library(multipanelfigure)
library(data.table)
library(tidyverse)
# tidylog must be attached AFTER dplyr/tidyr (loaded via tidyverse) so that
# its logging wrappers mask the dplyr/tidyr verbs, not the other way round.
library(tidylog)



# Estimate the mortality rate in a region across six age-by-sex groups.
# For each group we know the number of deaths and the time at risk (in days).
# age_g2 / age_g3 are 0/1 indicators (presumably for the 2nd and 3rd age
# groups; rows where both are 0 form the remaining age group) and sex_g is a
# 0/1 sex indicator (which sex is coded 1 is not stated here).
data <- list(
  days   = c(1114, 4962, 1809, 760, 2292, 732),
  age_g2 = c(0, 1, 0, 0, 1, 0),
  age_g3 = c(0, 0, 1, 0, 0, 1),
  sex_g  = c(1, 1, 1, 0, 0, 0),
  deaths = c(25, 156, 51, 13, 60, 14)
)
# Derive the number of groups from the data rather than hard-coding it;
# `n` is kept as the first list element.
data <- c(list(n = length(data$deaths)), data)

# Every per-group vector must have exactly one entry per group.
stopifnot(all(lengths(data[-1]) == data$n))

str(data)

# Compile the Stan program (path is relative to the working directory).
example5 <- cmdstan_model("STAN/example5.stan")

# Initial values for the sampler. NULL lets CmdStan generate its default
# random inits. To supply your own, use either a function returning a named
# list of parameter values (called once per chain) or a list containing one
# such named list per chain. cmdstanr requires a list of inits to have one
# element per chain, so an empty list() is not a valid "use defaults" value.
initial_values <- NULL

example5_fit <- example5$sample(
  data = data,
  seed = 123,
  iter_warmup = 2000,
  iter_sampling = 2000,
  chains = 3,
  parallel_chains = 3,
  save_warmup = TRUE,
  thin = 1,
  max_treedepth = 15,
  init = initial_values
)
# Posterior summary of every model quantity, converted to a data.table.
example5_table <- example5_fit$summary() %>% setDT()
example5_table

# Range of R-hat across all quantities, reusing the table above instead of
# recomputing the summary. na.rm = TRUE because R-hat is NA for quantities
# whose draws are constant, which would otherwise make the whole range NA.
range(example5_table$rhat, na.rm = TRUE)








# Posterior draws from the fit (cmdstanr's draws() excludes warmup by default).
post_draws <- example5_fit$draws()

# Parameters shown in the diagnostic figure.
# NOTE(review): assumes example5.stan defines a scalar beta0 and a vector beta
# with at least two elements -- confirm against the model.
diag_pars <- c("beta0", "beta[1]", "beta[2]")

# 6-column x 10-row layout: trace plots (rows 1-2), per-chain density
# overlays (rows 3-4) and one autocorrelation panel per parameter
# (rows 5-10, two columns each).
figure1 <- multi_panel_figure(
  columns = 6, rows = 10, panel_label_type = "upper-roman"
)
# %<>% must only start the chain (it assigns the final result back to
# figure1); every later step uses the ordinary %>% pipe.
figure1 %<>%
  fill_panel(
    mcmc_trace(post_draws, pars = diag_pars),
    column = 1:6, row = 1:2
  ) %>%
  fill_panel(
    mcmc_dens_overlay(post_draws, pars = diag_pars),
    column = 1:6, row = 3:4
  ) %>%
  fill_panel(
    mcmc_acf_bar(post_draws, pars = diag_pars[1]),
    column = 1:2, row = 5:10
  ) %>%
  fill_panel(
    mcmc_acf_bar(post_draws, pars = diag_pars[2]),
    column = 3:4, row = 5:10
  ) %>%
  fill_panel(
    mcmc_acf_bar(post_draws, pars = diag_pars[3]),
    column = 5:6, row = 5:10
  )
# figure1  # uncomment to draw the figure






# Post-warmup draws as a data.frame. data.frame() converts column names to
# syntactic names, e.g. "mean_deaths[1]" becomes "mean_deaths.1.".
posterior_sample <- data.frame(
  example5_fit$draws(format = "matrix", inc_warmup = FALSE)
)

# DIC for the Poisson model ----
# Deviance D(theta) = -2 * log p(y | theta).
#   D_hat = D at the plug-in estimate (posterior mean of each group's
#           Poisson mean, mean_deaths[i])
#   D_bar = posterior mean of D
#   pD    = D_bar - D_hat   (effective number of parameters)
#   DIC   = D_hat + 2 * pD  (= D_bar + pD)
N_size <- data$n

mean_death_cols <- paste0("mean_deaths.", seq_len(N_size), ".")
missing_cols <- setdiff(mean_death_cols, names(posterior_sample))
if (length(missing_cols) > 0) {
  stop("Draws are missing columns: ", paste(missing_cols, collapse = ", "),
       call. = FALSE)
}

# Plug-in Poisson means, then one log-density per group (dpois is vectorized
# over counts and means).
theta_hat <- colMeans(posterior_sample[, mean_death_cols, drop = FALSE])
log_lik_theta_hat <- dpois(data$deaths, theta_hat, log = TRUE)
total_log_lik_theta_hat <- sum(log_lik_theta_hat)

# NOTE(review): assumes the Stan model emits a scalar generated quantity
# `log_lik` equal to the full-data Poisson log-likelihood, including the
# -log(y!) term that dpois() includes -- confirm against example5.stan.
mean_log_lik_theta <- mean(posterior_sample[, "log_lik"])

pD <- 2 * (total_log_lik_theta_hat - mean_log_lik_theta)
DIC <- (-2 * total_log_lik_theta_hat) + (2 * pD)
# D_bar + pD: algebraically identical to DIC above (both equal
# 2 * total_log_lik_theta_hat - 4 * mean_log_lik_theta). Despite its name,
# this is NOT the DIC3 variant of Celeux et al. (2006).
DIC3 <- (-2 * mean_log_lik_theta) + pD


